import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from simple_tokenizer import SimpleTokenizer as _Tokenizer
import matplotlib
# Smoke test of the project tokenizer: encode a toy string and show the result.
_tokenizer = _Tokenizer()
ret = _tokenizer.encode("a b c")
print(ret)
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Read the RTE dev split as raw tab-separated rows; the first (header) row is
# discarded. Fields are not stripped here, so the last column keeps its
# trailing newline until `ans` is built below.
ans_path = "./data/glue/RTE/dev.tsv"
with open(ans_path) as f:
    rows = [line.split("\t") for line in f]
data = rows[1:]
df = pd.DataFrame(data, columns=['idx', 'sentence1', 'sentence2', 'label'])
# NOTE(review): iterating a DataFrame yields its column labels, so this prints
# the length of each column *name*, not of the rows -- confirm intent.
print([len(column) for column in df])

# Gold labels (whitespace-stripped) and the two sentences of every pair.
ans = [label.strip() for label in df["label"]]
s1 = df["sentence1"].tolist()
s2 = df["sentence2"].tolist()
#GOOD
# good[i] = number of seeds (out of 5) for which the XATTNBERT run predicted
# dev example i correctly.
good = {k: 0 for k in range(len(ans))}
for seed in range(1, 6):
    pred_path = f"./results/finetune/distil_main/bertxclip/Epoch01_XATTNBERT/RTE/lr-1e-4seed-{seed}/predict_results_rte.txt"
    # read_csv opens the path itself; no separate open() handle is needed.
    dfgood = pd.read_csv(pred_path, sep="\t")
    # NOTE(review): this prints the dev-set size, not the number of
    # predictions that were just loaded -- confirm intent.
    print(len(df))
    predg = dfgood["prediction"]

    for i, gold in enumerate(ans):
        is_correct = gold == predg[i]
        print(gold, predg[i], is_correct)
        if is_correct:
            good[i] += 1
#print(good)
#s()
#BAD
# bad[i] = number of seeds (out of 5) for which the bert-base-uncased baseline
# predicted dev example i correctly.
bad = {k: 0 for k in range(len(ans))}
# The loop variable must be `seed`: the path below interpolates {seed}. With
# any other name the f-string silently reuses the value left over from the
# previous loop and the same prediction file is read five times.
for seed in range(1, 6):
    pred_path2 = f"./results/finetune/bert-base-uncased/RTE/lr-1e-4-seed-{seed}/predict_results_rte.txt"
    # read_csv opens the path itself; no separate open() handle is needed.
    dfbad = pd.read_csv(pred_path2, sep="\t")

    predb = dfbad["prediction"]
    for i, gold in enumerate(ans):
        if gold == predb[i]:
            bad[i] += 1
# (example index, good-minus-bad correct-seed count), largest advantage first.
# The sort is stable, so ties keep their original example order.
diff = [(idx, good[idx] - bad[idx]) for idx in range(len(ans))]
diff.sort(key=lambda pair: pair[1], reverse=True)
print(diff)

# Running total of the per-example differences; accumulated further below.
advantage = 0

import benepar
import nltk
from nltk import word_tokenize
# benepar.download("benepar_en3")
# parser = benepar.Parser("benepar_en3")
import json

# MSCOCO caption files whose sentences make up the "visually grounded" corpus.
fs = [
    "/mnt/HDD/research/taiwanai/xattnbertfast/data/mscoco_minival.json",
    "/mnt/HDD/research/taiwanai/xattnbertfast/data/mscoco_train.json",
    "/mnt/HDD/research/taiwanai/xattnbertfast/data/mscoco_nominival.json",
]

# Stop words (plus "." and ",") that are excluded from all word statistics.
# Built once up front: the set does not depend on the file being read.
stopwords = {".",",","ourselves", "hers", "between", "yourself", "but", "again", "there", "about", "once", "during", "out", "very", "having", "with", "they", "own", "an", "be", "some", "for", "do", "its", "yours", "such", "into", "of", "most", "itself", "other", "off", "is", "s", "am", "or", "who", "as", "from", "him", "each", "the", "themselves", "until", "below", "are", "we", "these", "your", "his", "through", "don", "nor", "me", "were", "her", "more", "himself", "this", "down", "should", "our", "their", "while", "above", "both", "up", "to", "ours", "had", "she", "all", "no", "when", "at", "any", "before", "them", "same", "and", "been", "have", "in", "will", "on", "does", "yourselves", "then", "that", "because", "what", "over", "why", "so", "can", "did", "not", "now", "under", "he", "you", "herself", "has", "just", "where", "too", "only", "myself", "which", "those", "i", "after", "few", "whom", "t", "being", "if", "theirs", "my", "against", "a", "by", "doing", "it", "how", "further", "was", "here", "than"}
# Also add whatever pieces word_tokenize produces for each stop word, so the
# filter matches the tokenizer's output. The extra tokens are materialised
# first because a set must not be mutated while it is being iterated.
extra_stop_tokens = {token for word in stopwords for token in word_tokenize(word)}
stopwords |= extra_stop_tokens

# Flat list of every non-stop-word token of every MSCOCO sentence.
corpus = []
for fpath in fs:
    with open(fpath, "r") as f:
        data = json.load(f)

    for entry in data:
        for sentence in entry["sentf"]["mscoco"]:
            corpus.extend(w for w in word_tokenize(sentence) if w not in stopwords)
from collections import Counter
# Token frequencies over the MSCOCO corpus.
c = Counter(corpus)
from itertools import dropwhile

# most_common() is a list sorted by descending count, so skipping the leading
# entries with count >= 100 leaves exactly the tail to be pruned.
rare_tail = list(dropwhile(lambda item: item[1] >= 100, c.most_common()))

# ng: total occurrences of the pruned words; they are then removed from `c`.
ng = sum(count for _, count in rare_tail)
for word, _ in rare_tail:
    del c[word]

# g: total occurrences of the words that remain in `c`.
g = sum(c.values())
print(c)
print(len(c))
print("MSCOCO grounded_ratio", g / (g + ng))

# Statistics per bucket: "pos" (good model correct on more seeds), "neg"
# (fewer seeds) and "same" (equal). len1/len2 collect token counts of
# sentence1/sentence2; grounded/nongrounded are weighted counts of
# non-stop-word tokens that are / are not in the pruned MSCOCO vocabulary `c`;
# isbetter counts the examples in each bucket.
compare = {
    "len1": {"pos": [], "neg": [], "same": []},
    "len2": {"pos": [], "neg": [], "same": []},
    "grounded": {"pos": 0, "neg": 0, "same": 0},
    "nongrounded": {"pos": 0, "neg": 0, "same": 0},
    "isbetter": {"pos": 0, "neg": 0, "same": 0},
}
for example_idx, delta in diff:
    advantage += delta
    print("**", delta)
    print(s1[example_idx])

    # Bucket by the sign of the difference; the word-count weight is its
    # magnitude, or 1 for ties.
    if delta > 0:
        bucket, weight = "pos", delta
    elif delta < 0:
        bucket, weight = "neg", -delta
    else:
        bucket, weight = "same", 1
    compare["isbetter"][bucket] += 1

    tokens1 = word_tokenize(s1[example_idx])
    tokens2 = word_tokenize(s2[example_idx])
    compare["len1"][bucket].append(len(tokens1))
    compare["len2"][bucket].append(len(tokens2))

    # Tokens of both sentences feed the same two counters.
    for token in tokens1 + tokens2:
        if token in stopwords:
            continue
        kind = "grounded" if token in c else "nongrounded"
        compare[kind][bucket] += weight
    print(ans[example_idx])
print(advantage)

# Bucket sizes, then the mean token count of sentence1 and of sentence2 for
# the "pos" and "neg" buckets, then the raw grounded / non-grounded counters.
# NOTE(review): an empty "pos" or "neg" bucket raises ZeroDivisionError here.
print(compare["isbetter"])
for length_key in ("len1", "len2"):
    for bucket in ("pos", "neg"):
        lengths = compare[length_key][bucket]
        print(sum(lengths) / len(lengths))
print(compare["grounded"], compare["nongrounded"])

def _bucket_percentages(bucket):
    """Return [share of examples in `bucket`, grounded share of its words], in percent."""
    total_examples = sum(compare["isbetter"].values())
    grounded_words = compare["grounded"][bucket]
    counted_words = grounded_words + compare["nongrounded"][bucket]
    return [
        compare["isbetter"][bucket] / total_examples * 100,
        grounded_words / counted_words * 100,
    ]


def _bucket_mean_lengths(bucket):
    """Return [mean token count of sentence1, of sentence2] for `bucket`."""
    return [
        sum(compare[key][bucket]) / len(compare[key][bucket])
        for key in ("len1", "len2")
    ]


# Two bar groups per panel, centred at x = 0 and x = 1.
x = np.arange(2)
l_percent = ["RTE Entry Type", "Visually Grounded Ratio"]
l_len = ["Length of Sentence1", "Length of Sentence2"]

# NOTE(review): any empty bucket raises ZeroDivisionError in the lines below.
poss_percent = _bucket_percentages("pos")
poss_len = _bucket_mean_lengths("pos")
same_percent = _bucket_percentages("same")
same_len = _bucket_mean_lengths("same")
negs_percent = _bucket_percentages("neg")
negs_len = _bucket_mean_lengths("neg")
#plt.figure()
# Two stacked panels in a 3.13 x 1.8 inch figure with an 8 pt base font:
# percentages on top (ax1), sentence lengths in words below (ax2).
plt.rcParams['font.size'] = 8
fig, (ax1, ax2) = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(3.13, 1.8),
    gridspec_kw={'height_ratios': [0.545, 0.455]},
)
fig.subplots_adjust(hspace=.35)

# Each group holds three bars: "same" to the left, "pos" centred, "neg" right.
bar_width = 0.2
left_x = x - bar_width
right_x = x + bar_width

b1p = ax1.bar(x, poss_percent, width=bar_width, color='blue', align='center')
# Fully transparent bars (alpha=0); they are only used as the blank second
# entry of the legend below.
b_placeholder = ax1.bar(left_x, poss_percent, width=bar_width, color='purple', alpha=0)
b1same = ax1.bar(left_x, same_percent, width=bar_width, color='green', align='center')
b1n = ax1.bar(right_x, negs_percent, width=bar_width, color='r', align='center')

b2p = ax2.bar(x, poss_len, width=bar_width, color='b', align='center')
b2same = ax2.bar(left_x, same_len, width=bar_width, color='g', align='center')
b2n = ax2.bar(right_x, negs_len, width=bar_width, color='r', align='center')

# Frameless two-column legend on the top panel.
legend_handles = [b1same, b_placeholder, b1p, b1n]
legend_labels = ['on par', '', 'improved', 'worsened']
legend = ax1.legend(
    legend_handles,
    legend_labels,
    ncol=2,
    frameon=False,
    borderpad=0.2,
    labelspacing=0.2,
    columnspacing=0.4,
)
legend_frame = legend.get_frame()
legend_frame.set_alpha(None)
legend_frame.set_facecolor('none')

# Top panel: percent-formatted y axis fixed to 0-100.
ax1.set_xticks(x)
ax1.set_xticklabels(l_percent)
ax1.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter())
ax1.set_ylim([0, 100])

# Bottom panel: y axis fixed to 0-80, labelled "words".
ax2.set_xticks(x)
ax2.set_xticklabels(l_len)
ax2.set_ylim([0, 80])
ax2.set_ylabel("words")
def autolabel(rects, p, percent):
    """
    Write each bar's height, to one decimal place, just above the bar.

    rects   -- iterable of bars exposing get_height(), get_x() and get_width()
    p       -- object whose text() method draws the label (called once per bar)
    percent -- accepted from callers but not used by this function
    """
    for rect in rects:
        height = rect.get_height()
        center_x = rect.get_x() + rect.get_width()/2.
        # Label is centred horizontally and its bottom sits 0.1 above the bar top.
        p.text(center_x, height+0.1, '%.1f' % height, ha='center', va='bottom')

# Annotate every bar with its height: percentage bars on the top panel,
# sentence-length bars on the bottom panel.
for bars in (b1p, b1same, b1n):
    autolabel(bars, ax1, True)
for bars in (b2p, b2same, b2n):
    autolabel(bars, ax2, False)

# Display the figure, then write it out as PDF and PNG and clear the current
# pyplot figure.
plt.show()
fig.savefig("./src/analysis/analysis.pdf", bbox_inches="tight")
fig.savefig("./src/analysis/analysis.png", bbox_inches="tight")
plt.clf()
# The script used to end by calling `s`, which is not a function (the name is
# bound to a stop-word string earlier in the file), so every run finished with
# a TypeError traceback. Nothing follows this point, so the call was dropped.
# print(diff)
# for i in range(-1,-6,-1):
# 	print("**",diff[i][1])
# 	print(s1[diff[i][0]])
# 	print(s2[diff[i][0]])
# 	print(ans[diff[i][0]])
# g2b = 0
# b2g = 0
# for i in range(len(ans)):
# 	if ans[i] != predb[i] and ans[i] == predg[i]:
# 		b2g += 1
# 	if ans[i] != predg[i] and ans[i] == predb[i]:
# 		g2b += 1
# 	else:
# 		pass
# print("b2g : ",b2g/len(ans))
# print("g2b : ",g2b/len(ans))
